# Source code for lmcsc.generation

import os
import torch
import torch.distributed as dist
from typing import Tuple, Union, List, Optional
from torch import nn

import warnings

from transformers import (
    BeamScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
)
from transformers.generation.stopping_criteria import validate_stopping_criteria
from transformers.generation.utils import (
    GenerateBeamOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateBeamDecoderOnlyOutput,
)

from lmcsc.common import HALF_MIN, MIN
from transformers import AutoModelForCausalLM
from lmcsc.obversation_generator import BaseObversationGenerator

def token_transformation_to_probs(
    self, observed_sequence: str
) -> Tuple[List[int], List[float], dict]:
    """Turn an observed sequence into candidate token ids and their weights.

    ``self.transformation_type.get_transformation_type`` returns a mapping whose
    keys are candidate token ids and whose values are transformations (iterables
    of items ``t``), plus an ``original_token_length`` dict. Each candidate's
    weight is the sum of ``self.distortion_probs[t]`` over the items of its
    transformation. These sums are memoized in ``self.transformation_type_cache``,
    keyed by the transformation itself.

    Args:
        observed_sequence (str): The observed text to transform.

    Returns:
        Tuple[List[int], List[float], dict]:
            - candidate token ids, in the mapping's iteration order;
            - the matching weights, aligned index-for-index with the ids;
            - the ``original_token_length`` dict from
              ``get_transformation_type``, passed through unchanged.
    """
    transformations, original_token_length = (
        self.transformation_type.get_transformation_type(observed_sequence)
    )
    weight_cache = self.transformation_type_cache

    token_ids = list(transformations)
    weights = []
    for trans in transformations.values():
        # Memoize per transformation so a repeated transformation is summed once.
        if trans not in weight_cache:
            weight_cache[trans] = sum(self.distortion_probs[t] for t in trans)
        weights.append(weight_cache[trans])

    return token_ids, weights, original_token_length
def get_distortion_probs(
    self,
    batch_observed_sequences: List[List[str]],
    eos_token_id: Union[int, List[int]],
) -> Tuple[List[int], List[int], List[int], List[float], List[List[dict]], List[bool]]:
    """Flatten per-beam candidate tokens and distortion weights for a batch.

    For every beam of every batch item:
    - A non-empty observed sequence gets its candidate token ids, weights and
      ``original_token_length`` dict. These come from ``self.cache`` or, on a
      miss, from ``self.token_transformation_to_probs``, and are then cached.
    - An empty observed sequence yields only the EOS token id(s), each with
      weight 0.0 and an empty ``original_token_length`` dict. It also sets that
      beam's ``force_eos`` flag.

    The EOS entry is deliberately NOT stored in ``self.cache``. That cache is
    keyed by the observed sequence alone and persists across calls, so caching
    ``""`` would return stale EOS id(s) whenever a later call passes a
    different ``eos_token_id``.

    Args:
        batch_observed_sequences (List[List[str]]): Observed sequences, indexed
            as ``[batch_index][beam_index]``.
        eos_token_id (Union[int, List[int]]): An EOS token id, or a list
            (or tuple) of EOS token ids.

    Returns:
        Tuple of six lists:
            - batch index of each flattened (beam, candidate) entry;
            - beam index of each flattened entry;
            - token id of each flattened entry;
            - weight of each flattened entry (0.0 for EOS entries), presumably
              log-space distortion scores (TODO confirm);
            - ``original_token_length`` dicts, indexed ``[batch_index][beam_index]``;
            - one bool per beam (batch-major, then beam order), True when the
              beam's observed sequence is empty.
    """
    cache = self.cache

    # EOS candidates depend only on this call's argument; build them once.
    if isinstance(eos_token_id, (list, tuple)):
        eos_indices = list(eos_token_id)
    else:
        eos_indices = [eos_token_id]
    eos_weights = [0.0] * len(eos_indices)

    batch_indices, beam_indices, token_indices, distortion_probs = [], [], [], []
    force_eos = []
    original_token_lengths = []
    for batch_index, observed_sequences in enumerate(batch_observed_sequences):
        beam_original_token_lengths = []
        for beam_index, observed_sequence in enumerate(observed_sequences):
            is_empty = len(observed_sequence) == 0
            if is_empty:
                # Not cached: the result depends on eos_token_id, not the sequence.
                indices, weight, original_token_length = eos_indices, eos_weights, {}
            elif observed_sequence in cache:
                indices, weight, original_token_length = cache[observed_sequence]
            else:
                indices, weight, original_token_length = self.token_transformation_to_probs(
                    observed_sequence
                )
                cache[observed_sequence] = (indices, weight, original_token_length)

            force_eos.append(is_empty)
            # extend() copies, so cached/shared lists are never mutated here.
            batch_indices.extend([batch_index] * len(indices))
            beam_indices.extend([beam_index] * len(indices))
            token_indices.extend(indices)
            distortion_probs.extend(weight)
            beam_original_token_lengths.append(original_token_length)
        original_token_lengths.append(beam_original_token_lengths)

    return (
        batch_indices,
        beam_indices,
        token_indices,
        distortion_probs,
        original_token_lengths,
        force_eos,
    )
@torch.jit.script
def distortion_probs_to_cuda_jit(
    template_tensor: torch.Tensor,
    force_eos: torch.Tensor,
    batch_size: int,
    num_beams: int,
    batch_beam_size: int,
    vocab_size: int,
    _batch_indices: List[int],
    _beam_indices: List[int],
    _token_indices: List[int],
    _distortion_probs: torch.Tensor,
) -> torch.Tensor:
    """Scatter sparse distortion scores into a dense per-beam score matrix.

    First, the rows of ``template_tensor`` whose ``force_eos`` flag is set are
    filled with a large negative floor. Then each ``(batch, beam, token)``
    coordinate given by the three index lists receives the matching value of
    ``_distortion_probs``. ``template_tensor`` itself is not modified because
    ``masked_fill`` is out-of-place.

    Despite the name, no device transfer happens here: the result stays on
    ``template_tensor``'s device.

    Args:
        template_tensor (torch.Tensor): Base scores. They are viewed as
            ``(batch_size, num_beams, vocab_size)``.
        force_eos (torch.Tensor): Boolean row mask, presumably of shape
            ``(batch_beam_size,)``; it is broadcast across the vocab axis.
        batch_size (int): Number of batch items.
        num_beams (int): Beams per batch item.
        batch_beam_size (int): Row count of the returned matrix.
        vocab_size (int): Column count (vocabulary size).
        _batch_indices (List[int]): Batch coordinate of each sparse entry.
        _beam_indices (List[int]): Beam coordinate of each sparse entry.
        _token_indices (List[int]): Token coordinate of each sparse entry.
        _distortion_probs (torch.Tensor): Values written at those coordinates,
            one per entry. It is presumably in ``template_tensor``'s dtype
            (TODO confirm callers).

    Returns:
        torch.Tensor: The filled scores, viewed as ``(batch_beam_size, vocab_size)``.
    """
    # float16 cannot represent -1e32 (its largest magnitude is ~6.5e4), so it
    # gets a floor that fits; every other dtype gets -1e32.
    fill_value = -1e4 if template_tensor.dtype == torch.float16 else -1e32
    scores = template_tensor.masked_fill(force_eos.unsqueeze(1), fill_value)
    scores = scores.view(batch_size, num_beams, vocab_size)
    scores[_batch_indices, _beam_indices, _token_indices] = _distortion_probs
    return scores.view(batch_beam_size, vocab_size)